
feat: Improve flights.* dataset reproducibility #645

Merged · 29 commits into main from flights-repro · Dec 20, 2024
Conversation

@dangotbanned (Member) commented Dec 10, 2024

Addresses:

Description

This PR follows up on the work in #626 and #642 to provide a more self-contained and reproducible version of https://github.com/vega/vega-datasets/blob/5eaa256486deec59f3c4441d0abb568688ff5a81/scripts/flights.py.

Noteworthy changes

Specifications for each dataset are defined in _data/flights.toml

flights.toml

[[specs]]
# This spec won't be part of the final PR.
# Using it to demonstrate the deviation from [ISO_8601](https://en.wikipedia.org/wiki/ISO_8601)
range = [2001-01-01, 2001-03-31]
n_rows = 1_000
suffix = ".csv"
dt_format = "iso"
[[specs]]
start = 2001-01-01
end = 2001-03-31
n_rows = 2_000
suffix = ".json"
dt_format = "%Y/%m/%d %H:%M"
[[specs]]
start = 2001-01-01
end = 2001-03-31
n_rows = 5_000
suffix = ".json"
dt_format = "%Y/%m/%d %H:%M"
[[specs]]
start = 2001-01-01
end = 2001-03-31
n_rows = 10_000
suffix = ".json"
dt_format = "%Y/%m/%d %H:%M"
[[specs]]
start = 2001-01-01
end = 2001-03-31
n_rows = 20_000
suffix = ".json"
dt_format = "%Y/%m/%d %H:%M"
[[specs]]
start = 2001-01-01
end = 2001-03-31
n_rows = 200_000
suffix = ".json"
dt_format = "decimal"
columns = ["delay", "distance", "time"]
[[specs]]
start = 2001-01-01
end = 2001-03-31
n_rows = 200_000
suffix = ".arrow"
dt_format = "decimal"
columns = ["delay", "distance", "time"]
[[specs]]
start = 2001-01-01
end = 2001-06-30
n_rows = 3_000_000
suffix = ".parquet"


Manually visiting transtats and downloading monthly data is no longer required.

Prior method

Data Source:
1) Visit https://www.transtats.bts.gov/DL_SelectFields.asp?gnoyr_VQ=FGJ&QO_fu146_anzr=b0-gvzr
(link valid as of November 2024)
2) Download prezipped files, one per month.
Input Data Requirements:
- ZIP files containing BTS On-Time Performance data CSVs

Any missing data is fetched via asynchronous requests and written as .parquet for fast, efficient storage.

All you need to do is run the script.

New method

def download_sources(self) -> None:
    """
    Ensure all required source data is saved to ``self.input_dir``.

    Any month(s) that are missing will be requested from `transtats`_.

    .. _transtats:
        https://www.transtats.bts.gov
    """
    logger.info("Detecting required sources ...")
    if missing := self.missing_stems:
        asyncio.run(self._download_sources_async(missing))
        logger.info("Successfully downloaded all missing sources.")
    else:
        logger.info("Sources already downloaded.")

async def _download_sources_async(self, names: Iterable[str], /) -> list[Path]:
    """Request, write missing data."""
    session = niquests.AsyncSession(base_url=ROUTE_ZIP)
    aws = (_request_async(session, name) for name in names)
    buffers = await asyncio.gather(*aws)
    writes = (_write_zip_to_parquet_async(self.input_dir, buf) for buf in buffers)
    return await asyncio.gather(*writes)
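
The fan-out here is plain asyncio.gather: every request coroutine is created up front, then awaited concurrently. A dependency-free sketch of the same pattern (stand-in names, not the PR's actual helpers):

import asyncio

async def fetch(name: str) -> str:
    # Stand-in for _request_async: pretend each download takes a moment.
    await asyncio.sleep(0.1)
    return f"{name}.zip"

async def fetch_all(names: list[str]) -> list[str]:
    # Schedule all requests concurrently; total time ~= the slowest request.
    return await asyncio.gather(*(fetch(n) for n in names))

print(asyncio.run(fetch_all(["2001-01", "2001-02", "2001-03"])))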


Switching from pandas -> polars.

This is for both readability and performance.
The following two functions are roughly equivalent:

flights.process_flights_data

def process_flights_data(
    df: pd.DataFrame,
    num_rows: Optional[int] = None,
    random_seed: int = 42,
    datetime_convert: bool = True,
    datetime_format: DateTimeFormat = DateTimeFormat.MMDDHHMM,
    flag_date_changes: bool = False,
    columns: Optional[List[str]] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None
) -> pd.DataFrame:
    """Process flight data with specified columns and format."""
    # Set random seed for reproducibility
    np.random.seed(random_seed)
    # Filter cancelled flights and drop NA values first
    df = df[~df.Cancelled].dropna(subset=['ArrDelay', 'DepDelay', 'CRSDepTime'])
    # Add MMDDHHMM format check
    if datetime_format == DateTimeFormat.MMDDHHMM:
        unique_years = pd.to_datetime(df['FlightDate']).dt.year.unique()
        if len(unique_years) > 1:
            raise ValueError(
                f"MMDDHHMM format cannot be used with data spanning multiple years. "
                f"Found data from years: {sorted(unique_years)}. "
                "Please use ISO format (-d iso) for multi-year data."
            )
    # Calculate scheduled minutes and actual minutes for validation
    scheduled_minutes = (df['CRSDepTime'].astype(int) // 100 * 60 +
                         df['CRSDepTime'].astype(int) % 100)
    actual_minutes = scheduled_minutes + df['DepDelay'].astype(int)
    # Vectorized validation
    dep_time_minutes = (df['DepTime'] // 100 * 60 + df['DepTime'] % 100).where(df['DepTime'] != 2400, 0)
    # Handle midnight crossing
    time_diff = np.abs(dep_time_minutes - actual_minutes % 1440)
    time_diff = np.minimum(time_diff, 1440 - time_diff)
    # Check consistency with 1-minute tolerance
    all_consistent = (~pd.isnull(df['DepTime'])) & (time_diff <= 1)
    logging.info(
        "All reported departure times are consistent with delays"
        if all_consistent.all()
        else "Some reported departure times are inconsistent with delays"
    )
    # Calculate actual datetime
    flight_dates = pd.to_datetime(df['FlightDate'])
    df['actual_datetime'] = (flight_dates +
                             pd.to_timedelta(scheduled_minutes, unit='m') +
                             pd.to_timedelta(df['DepDelay'], unit='m'))
    # Special handling for 2400 DepTime (midnight next day)
    mask_2400 = df['DepTime'] == 2400
    if mask_2400.any():
        df.loc[mask_2400, 'actual_datetime'] = flight_dates[mask_2400] + pd.Timedelta(days=1)
    # Store scheduled dates before filtering
    df['scheduled_date'] = flight_dates.dt.date
    # Filter by date range if specified
    if start_date or end_date:
        if start_date:
            start_dt = pd.to_datetime(start_date)
            df = df[df['actual_datetime'].dt.date >= start_dt.date()]
        if end_date:
            end_dt = pd.to_datetime(end_date)
            df = df[df['actual_datetime'].dt.date <= end_dt.date()]
        if len(df) == 0:
            logging.warning("No flights found within specified date range")
        else:
            logging.info(f"Found {len(df)} flights within specified date range")
    # Calculate date changes if requested
    if flag_date_changes:
        df['date_changed'] = df['scheduled_date'] != df['actual_datetime'].dt.date
    df = df.drop('scheduled_date', axis=1)  # Clean up temporary column
    # Create result DataFrame with all possible columns
    result = pd.DataFrame({
        'date': df['actual_datetime'],
        'delay': df['ArrDelay'].astype(int),
        'distance': df['Distance'].astype(int),
        'origin': df['Origin'],
        'destination': df['Dest'],
        'ScheduledFlightDate': df['FlightDate'],
        'ScheduledFlightTime': df['CRSDepTime'],
        'DepDelay': df['DepDelay'].astype(int)
    })
    if flag_date_changes:
        result['date_changed'] = df['date_changed']
    # Format datetime according to specified format
    if datetime_convert:
        result = format_datetime(result, datetime_format)
    # Handle sampling if requested
    if num_rows is not None and num_rows < len(result):
        result = result.sample(n=num_rows, random_state=random_seed)
        result = result.sort_values('time' if datetime_format == DateTimeFormat.DECIMAL else 'date')
        logging.info(f"Randomly sampled {num_rows} rows from {len(df)} total rows")
    # Select and order columns if specified
    if columns:
        validate_columns(columns, datetime_format, flag_date_changes)
        result = result[columns]
    else:
        default_cols = ['time' if datetime_format == DateTimeFormat.DECIMAL else 'date',
                        'delay', 'distance', 'origin', 'destination']
        result = result[default_cols]
    return result
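
The HHMM-to-minutes arithmetic used in the validation step above is easy to sanity-check in isolation:

# Standalone check of the HHMM -> minutes-since-midnight conversion above.
def hhmm_to_minutes(hhmm: int) -> int:
    return (hhmm // 100) * 60 + hhmm % 100

assert hhmm_to_minutes(630) == 390    # 06:30
assert hhmm_to_minutes(2359) == 1439  # 23:59
# 2400 is handled separately (mapped to 0), since 24:00 wraps to the next day.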

flights2.SourceMap.clean

Ignoring the docstring, it is much more concise and IMO a lot easier to understand.

@staticmethod
def clean(ldf: pl.LazyFrame, /) -> pl.LazyFrame:
    """
    Fix *known* dataset issues, coerce types, rename columns.

    Parameters
    ----------
    ldf
        Monthly datasets, concatenated as a single table.

    Notes
    -----
    - Rows containing cancelled flights or null values are dropped (~3.16%)
    - Non-compliant* `ISO-8601`_ times are corrected

    *Invalid midnight representation prior to `ISO-8601-1-2019-Amd-1-2022`_

    **Input schema**:

        {
            "FlightDate": datetime.date,
            "CRSDepTime": str,
            "DepTime": str,
            "DepDelay": float,
            "ArrDelay": float,
            "Distance": float,
            "Origin": str,
            "Dest": str,
            "Cancelled": float,
        }

    **Output schema**:

        {
            "date": datetime.datetime,
            "delay": int,
            "distance": int,
            "origin": str,
            "destination": str,
            "ScheduledFlightDate": datetime.date,
            "ScheduledFlightTime": datetime.time,
            "DepDelay": int,
        }

    .. _ISO-8601:
        https://en.wikipedia.org/wiki/ISO_8601
    .. _ISO-8601-1-2019-Amd-1-2022:
        https://cdn.standards.iteh.ai/samples/81801/f527872a9fe34281ae3a4af8e730f3f8/ISO-8601-1-2019-Amd-1-2022.pdf#page=8
    """
    cancelled = col("Cancelled").cast(bool)
    flight_date = col("FlightDate")
    dep_time = col("DepTime")
    times = cs.ends_with("DepTime")
    wrap_midnight = times.str.replace("2400", "0000").str.to_time("%H%M")
    datetime = flight_date.dt.combine(dep_time)
    flight_date_corrected = (
        pl.when(dep_time == pl.time(0, 0, 0, 0))
        .then(datetime.dt.offset_by("1d"))
        .otherwise(datetime)
    )
    return (
        ldf.filter(
            ~pl.any_horizontal(cancelled, dep_time == "", cs.float().is_null())
        )
        .with_columns(wrap_midnight, cs.float().cast(int))
        .select(
            flight_date_corrected.alias("date"),
            col("ArrDelay").alias("delay"),
            col("Distance", "Origin").name.to_lowercase(),
            col("Dest").alias("destination"),
            flight_date.alias("ScheduledFlightDate"),
            col("CRSDepTime").alias("ScheduledFlightTime"),
            "DepDelay",
        )
    )
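
To see the midnight correction in isolation, here is a minimal standalone sketch in plain polars (a demonstration of the same expressions, not the PR's code):

from datetime import date

import polars as pl

df = pl.DataFrame(
    {
        "FlightDate": [date(2001, 1, 1), date(2001, 1, 1)],
        "DepTime": ["2400", "0630"],  # "2400" is the non-compliant midnight
    }
)
out = df.with_columns(
    pl.col("DepTime").str.replace("2400", "0000").str.to_time("%H%M")
).with_columns(
    date=pl.when(pl.col("DepTime") == pl.time(0, 0))
    .then(pl.col("FlightDate").dt.combine(pl.col("DepTime")).dt.offset_by("1d"))
    .otherwise(pl.col("FlightDate").dt.combine(pl.col("DepTime")))
)
print(out.get_column("date").to_list())
# [datetime.datetime(2001, 1, 2, 0, 0), datetime.datetime(2001, 1, 1, 6, 30)]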

Tasks

- (Locally) Using an adapted version of https://github.com/vega/altair/blob/main/pyproject.toml
- Can add to the repo later if desired (#643 (comment))
- `_write_zip_to_parquet` has a detailed doc breaking down the benefits
- Overall, much faster and requires less storage
- Doc is now very explicit in what *cleaning* means here

Resolves #645 (comment)
- short summary for `Flights.run`
- examples for each constructor in class-level doc

#645 (comment)
- moves `Flights.scan_sources` -> `SourceMap.from_specs`
- rename `SourceMap.add_dependency` -> `SourceMap.add_spec`

#645 (comment)
- Makes the doc visible in more places
- Added a description of what `None` means per-extension
- Fixed typo `"ISO 6801"` -> `"ISO 8601"`
@dangotbanned dangotbanned changed the title feat(DRAFT): Improve flights.* dataset reproducibility feat: Improve flights.* dataset reproducibility Dec 18, 2024
@dangotbanned dangotbanned marked this pull request as ready for review December 18, 2024 14:49
@domoritz (Member):

Should I wait for the last todos to be done?

@dangotbanned (Member, Author) commented Dec 18, 2024

> Should I wait for the last todos to be done?

Ah @domoritz I wasn't expecting you so soon!

Yeah I'm done now, was just filling out the description to make reviewing easier.

Edit

You should be able to simply run the script locally, without arguments.
If we wanted to run this in CI, I'd probably need to add a static seed for polars.DataFrame.sample
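
For reference, polars' sample already accepts a seed, so that would be a small change (a sketch, not the PR's code):

import polars as pl

df = pl.DataFrame({"x": range(100)})

# Without `seed`, polars draws a different sample on each run;
# pinning it makes the output reproducible across CI runs.
sampled = df.sample(n=10, seed=42)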

@domoritz domoritz requested a review from dsmedia December 18, 2024 16:41
@domoritz (Member):

Great. @dsmedia can you do a detailed review? I will do a rough check. We don't need to run the scripts on the CI.

@dangotbanned (Member, Author):

#645 (comment)

Thanks @domoritz

@dsmedia (Collaborator) commented Dec 19, 2024

This looks great @dangotbanned. The pattern of downloading data from this API, transforming it, and generating clean sample datasets of various sizes is something that I could see being useful for analysis of this flights dataset well beyond the purposes of this repo.

Given that broader potential, one suggestion that might add some flexibility to the valid row counts while keeping things clean: Instead of restricting row counts to specific numbers,

type Rows = Literal[
    1_000,
    2_000,
    5_000,
    10_000,
    20_000,
    100_000,
    200_000,
    500_000,
    1_000_000,
    3_000_000,
    5_000_000,
    10_000_000,
    100_000_000,
    500_000_000,
    1_000_000_000,
]
"""Number of rows to include in the output."""

we could allow any multiple of 1,000 for thousands and any multiple of 1,000,000 for millions. For example:

def validate_row_count(n: int) -> bool:
    if n >= 1_000_000:
        return n % 1_000_000 == 0  # Must be exact millions
    return n % 1_000 == 0  # Must be exact thousands

This would let users choose numbers like:

  • 66,000 -> "flights-66k.parquet"
  • 123,000 -> "flights-123k.parquet"
  • 676,000,000 -> "flights-676m.parquet"

The nice thing about this approach is that:

  1. It maintains unambiguous filenames (each valid number maps to exactly one filename)
  2. Filenames stay clean and readable (no decimals)
  3. Gives users more flexibility in choosing sample sizes for different use cases (e.g., testing, benchmarking, demos)
  4. The validation rule is simple to understand - just use exact thousands/millions

If this script gets adopted by other projects (which I could definitely see happening), this flexibility would be particularly valuable. Different projects need different sample sizes: some might want 66k rows for a specific benchmark, others 676m rows to test large-scale processing. Allowing that range while keeping clean, predictable filenames would make the script more broadly useful.

@dangotbanned (Member, Author) commented Dec 19, 2024

#645 (comment)

Thanks for taking a look @dsmedia!

> Given that broader potential, one suggestion that might add some flexibility to the valid row counts while keeping things clean: Instead of restricting row counts to specific numbers,

I didn't do a good job of documenting this here.
The fixed values defined in Rows are more hints than a validation rule.
The actual logic in Spec.name really just checks n_rows >= 1_000:

Spec.name

@property
def name(self) -> str:
    """
    Encodes a short form of ``n_rows`` into the file name.

    Examples
    --------
    Note that the final name depends on ``suffix``:

    | n_rows         | stem          |
    | -------------- | ------------- |
    | 10_000         | "flights-10k" |
    | 1_000_000      | "flights-1m"  |
    | 12_000_000_000 | "flights-12b" |
    """
    frac = self.n_rows // 1_000
    if frac >= 1_000_000:
        s = f"{frac // 1_000_000}b"
    elif frac >= 1_000:
        s = f"{frac // 1_000}m"
    elif frac >= 1:
        s = f"{frac}k"
    else:
        raise TypeError(self.n_rows)
    return f"{self._name_prefix}{s}{self.suffix}"


So the 3 examples you gave all work currently (at runtime), but 999 fails:

Example code block

from scripts.flights2 import Spec, DateRange

date_range = DateRange((2001, 1), (2001, 12))
spec_0 = Spec(date_range, 66_000, ".parquet")
spec_1 = Spec(date_range, 123_000, ".parquet")
spec_2 = Spec(date_range, 676_000_000, ".parquet")
spec_bad = Spec(date_range, 999, ".parquet")

print(spec_0.name)
print(spec_1.name)
print(spec_2.name)
print(spec_bad.name)

Traceback

flights-66k.parquet
flights-123k.parquet
flights-676m.parquet
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 12
     10 print(spec_1.name)
     11 print(spec_2.name)
---> 12 print(spec_bad.name)

File ../vega-datasets/scripts/flights2.py:415, in Spec.name(self)
    413     s = f"{frac}k"
    414 else:
--> 415     raise TypeError(self.n_rows)
    416 return f"{self._name_prefix}{s}{self.suffix}"

TypeError: 999

Admittedly though, all 4 of these equally trigger a type checker warning (statically) - which is not very intuitive:

[Screenshot: type checker warnings on the four Spec constructions]


The resolved environment for this script has a transitive dependency with some examples of runtime Annotated validation:

https://github.com/annotated-types/annotated-types/blob/b60ad7bff90019c7de3547df1c8e3586eb328423/annotated_types/__init__.py#L204-L225

I explored this a little with the DateTimeFormat type.
There, I've used a very loose approximation of what would be accepted by chrono:

DateTimeFormat

def is_chrono_str(s: Any) -> TypeIs[_ChronoFormat]:
    return s == "%Y/%m/%d %H:%M" or (isinstance(s, str) and s.startswith("%"))

def is_datetime_format(s: Any) -> TypeIs[DateTimeFormat]:
    return s in {"iso", "iso:strict", "decimal"} or is_chrono_str(s) or s is None

type _ChronoFormat = Literal["%Y/%m/%d %H:%M"] | Annotated[LiteralString, is_chrono_str]
"""https://docs.rs/chrono/latest/chrono/format/strftime/index.html"""

type DateTimeFormat = Literal["iso", "iso:strict", "decimal"] | _ChronoFormat | None
"""
Anything that is resolvable to a date/time column transform.

Notes
-----
When not provided:

- {``.arrow``, ``.parquet``} preserve temporal data types on write
- ``.json`` defaults to **"iso"**
- ``.csv`` defaults to **"iso:strict"**

Examples
--------
Each example will use the same input datetime:

    from datetime import datetime
    datetime(2020, 3, 1, 6, 30, 0)

**"iso"**, **"iso:strict"**: variants of `ISO 8601`_ used in `pl.Expr.dt.to_string`_:

    "2020-03-01 06:30:00.000000"
    "2020-03-01T06:30:00.000000"

**"decimal"**: represents **time only** with fractional minutes::

    6.5  # stored as a float

A format string using `chrono`_ specifiers:

    "%Y/%m/%d %H:%M" -> "2020/03/01 06:30"
    "%s"             -> "1583044200"  # UNIX timestamp
    "%c"             -> "Sun Mar 1 06:30:00 2020"
    "%T"             -> "06:30:00"
    "%Y-%B-%d"       -> "2020-March-01"
    "%e-%b-%Y"       -> " 1-Mar-2020"

.. _ISO 8601:
    https://en.wikipedia.org/wiki/ISO_8601
.. _pl.Expr.dt.to_string:
    https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.dt.to_string.html
.. _chrono:
    https://docs.rs/chrono/latest/chrono/format/strftime/index.html
"""

Question

@dsmedia would you be happy if I rewrite n_rows: Rows more in the style of dt_format: DateTimeFormat?

This would look like:

  • Some more content in the docstring
  • Replacing Literal[...] -> Annotated[int, is_rows]
  • Adding a runtime check for Spec.n_rows like Spec.dt_format
    • With the TypeError message explaining the constraint
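
A rough sketch of what that could look like (hypothetical names; the constraint follows the exact-thousands/millions rule suggested above):

from typing import Annotated, Any

def is_rows(n: Any) -> bool:
    """Exact thousands below one million; exact millions at or above."""
    if not isinstance(n, int) or n < 1_000:
        return False
    return n % 1_000_000 == 0 if n >= 1_000_000 else n % 1_000 == 0

type Rows = Annotated[int, is_rows]
"""Number of rows to include in the output."""

# Hypothetical runtime check, mirroring how Spec.dt_format is validated:
# if not is_rows(self.n_rows):
#     raise TypeError(
#         f"n_rows must be exact thousands/millions >= 1_000, got {self.n_rows!r}"
#     )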

Updated

@dsmedia hopefully I've addressed your concerns in (8ec3adf)

Since the validation logic started getting a bit unwieldy in Spec.__init__, I've moved that to Spec._validate.
There are also checks for valid columns and file extensions in there as well.

@dangotbanned dangotbanned merged commit a88ff4c into main Dec 20, 2024
3 checks passed
@dangotbanned dangotbanned deleted the flights-repro branch December 20, 2024 11:43